% Matlab file for plotting spiking data
close all
fontname = ("Times New Roman");

% ---- Recording / analysis parameters ----
num_RGC = 1;                        % number of RGC response files to process
runtime = 1500;                     % in ms
sample_time = 0.1;                  % in ms
sample_freq = 1000/sample_time;     % samples per second

% ---- Cell parameters and time axis ----
RGC_param = readmatrix("Matrices/RGC_parameters.csv");
t = readmatrix("Cell responses/t.csv");
t = t(2:15001);                     % 15000 time points, same length as the traces below
type = RGC_param(:, 1);             % first column of the parameter table

ON_RGCs = find(type == 1);          % row indices whose type code equals 1
OFF_RGCs = 1;                       % hard-coded: only cell 1 is treated as OFF

% ---- Load every requested RGC trace and detect its spikes ----
directory = "Cell responses/RGCs/";
S = dir(fullfile(directory, "*.csv"));

% Fail early with a readable message instead of an index error inside the loop.
if numel(S) < num_RGC
    error('Expected at least %d RGC response files in "%s" but found %d.', ...
        num_RGC, directory, numel(S));
end

num_spikes = zeros(1, num_RGC);     % preallocated: one spike count per processed cell
for k = 1:num_RGC
    F = fullfile(directory, S(k).name);
    % readmatrix replaces the deprecated csvread and matches the other reads in this script.
    raw_trace = readmatrix(F);
    % Keep samples 5001..20000 and force a column vector, so findpeaks returns
    % column outputs and per-cell spike counts do not depend on file orientation.
    S(k).data = reshape(raw_trace(5001:20000), [], 1);
    % A spike is a local maximum above -100 that lies at least 50 samples
    % (5 ms at sample_time = 0.1 ms) from any larger neighbouring peak.
    [S(k).spike_amps, S(k).spike_timing] = findpeaks(S(k).data, ...
        'MinPeakHeight', -100, 'MinPeakDistance', 50);
    num_spikes(k) = numel(S(k).spike_timing);
end

% First figure: placeholder for the ON-RGC raster. Its drawing code below is
% disabled, so only an empty figure with fixed x-limits (in samples) is created.
fig = figure;
xlim([0 15001]);
% line([3000 3000], [0 70], 'Color', 'k', 'LineStyle','--');

% count = 1;
% for k = 1:size(ON_RGCs, 1)
%     if((num_spikes(ON_RGCs(k))) > 1)
%         for i = 1:num_spikes(ON_RGCs(k))
%             line([S(ON_RGCs(k)).spike_timing(i) S(ON_RGCs(k)).spike_timing(i)], [(count-0.5) (count+0.5)], 'Color', 'b');
%         end
%         count = count + 1;
%     end
% end
% ylim([0 count]);
% axis off;
% fig.Position = ([600 300 300 200]);

% Second figure: OFF-RGC raster. Each cell with more than one spike gets its
% own row; every spike is drawn as a red vertical tick at its sample index.
fig = figure;
xlim([0 15001]);
% line([3000 3000], [0 70], 'Color', 'k', 'LineStyle','--');

count = 1;                                  % row that the next plotted cell will occupy
for cell_no = 1:size(OFF_RGCs, 1)
    rgc = OFF_RGCs(cell_no);
    n_spk = num_spikes(rgc);
    if n_spk <= 1
        continue                            % cells with at most one spike get no row
    end
    tick_x = reshape(S(rgc).spike_timing(1:n_spk), 1, []);
    tick_y = repmat([count-0.5; count+0.5], 1, n_spk);
    % With 2-by-n matrices, line() draws one separate tick per column.
    line([tick_x; tick_x], tick_y, 'Color', 'r');
    count = count + 1;
end
ylim([0 count]);
axis off;
set(fig, 'Position', [600 300 300 200]);

% %%
% fig = figure;
% hold on
% colormap gray
% xlim([0 1100]);
% ylim([-1 0]);
% % initial off
% x = [0 0 300 300];
% y = [-1 0 0 -1];
% c = [0 0 0 0];
% patch(x, y, c, "EdgeAlpha", 0);
% 
% 
% x = [300 300 500 500];
% y = [-1 0 0 -1];
% c = [1 1 1 1];
% patch(x, y, c, "EdgeAlpha", 0);
% 
% x = [500 500 1100 1100];
% y = [-1 0 0 -1];
% c = [0 0 0 0];
% patch(x, y, c, "EdgeAlpha", 0);
% hold off

%%
% Rod = readmatrix("Cell responses/PRs/PR_upper_214.csv");
% Cone = readmatrix("Cell responses/PRs/PR_lower_215.csv");
% HZ = readmatrix("Cell responses/HZs/HZ_lower_29.csv");
% RBC = readmatrix("Cell responses/BIPs/BIP_lower_26.csv");
% ONBC = readmatrix("Cell responses/BIPs/BIP_lower_28.csv");
% OFFBC = readmatrix("Cell responses/BIPs/BIP_lower_41.csv");
% Traces for the stacked-trace figure: one file from the AIIs folder and one
% RGC soma file (the other cell types above are currently disabled).
OFFRGC = readmatrix("Cell responses/RGCs/RGC_soma_0.csv");
AC = readmatrix("Cell responses/AIIs/AII_lower_5.csv");

% Rod = Rod(5002:20001);
% Cone = Cone(5002:20001);
% HZ = HZ(5002:20001);
% RBC = RBC(5002:20001);
% ONBC = ONBC(5002:20001);
% OFFBC = OFFBC(5002:20001);

% Keep samples 5001..20000 (15000 samples) of each trace.
trace_window = 5001:20000;
OFFRGC = OFFRGC(trace_window);
AC = AC(trace_window);


% Nine stacked tiles with no spacing; only the enabled traces fill tiles.
fig = figure;
tiledlayout(9, 1, "TileSpacing", "none")

% nexttile
% plot(t, Rod, 'k');
% xlim([0 1100]);
% ylim([-80 -40]);
% axis off
% 
% nexttile
% plot(t, Cone, 'm');
% xlim([0 1500]);
% ylim([-80 -40]);
% axis off
% 
% nexttile
% plot(t, HZ, 'k');
% xlim([0 1500]);
% ylim([-80 -30]);
% axis off
% 
% nexttile
% hold on;
% line([1500 1500], [-30 -20], 'Color', 'k');
% plot(t, RBC, 'k');
% xlim([0 1500]);
% ylim([-40 -20]);
% axis off
% 
% nexttile
% hold on;
% line([1500 1500], [-30 -20], 'Color', 'k');
% line([1400 1500], [-20 -20], 'Color', 'k');
% plot(t, ONBC, 'b');
% xlim([0 1500]);
% ylim([-40 -20]);
% axis off
% 
% nexttile
% hold on;
% line([1500 1500], [-60 -50], 'Color', 'k');
% plot(t, OFFBC, 'r');
% xlim([0 1500]);
% ylim([-60 -30]);
% axis off

% Tile for the AC trace (black). The two short black lines near x = 1500 are
% presumably scale bars: one horizontal (1400..1500) and one vertical (-50..-40).
ax_ac = nexttile;
hold(ax_ac, 'on');
line(ax_ac, [1400 1500], [-40 -40], 'Color', 'k');
line(ax_ac, [1500 1500], [-50 -40], 'Color', 'k');
plot(ax_ac, t, AC, 'k');
xlim(ax_ac, [0 1500]);
ylim(ax_ac, [-70 -40]);
axis(ax_ac, 'off');

% Tile for the OFF RGC trace (red) with one vertical black line (-40..10) at x = 1500.
ax_rgc = nexttile;
hold(ax_rgc, 'on');
line(ax_rgc, [1500 1500], [-40 10], 'Color', 'k');
plot(ax_rgc, t, OFFRGC, 'r');
xlim(ax_rgc, [0 1500]);
ylim(ax_rgc, [-80 30]);
axis(ax_rgc, 'off');

set(fig, 'Position', [400 200 300 600]);

%% Bin the detected spikes of every processed RGC
% For each cell k and each time bin i this fills
%   S(k).spikecount(i) - number of spikes whose sample index lies in
%                        ((i-1)*bin_samples, i*bin_samples]
%   S(k).spikefreq(i)  - 1000 / (mean inter-spike interval in ms) within the bin,
%                        or 0 when the bin holds fewer than two spikes.

bin_time = 1;     % ms
num_bins = runtime/bin_time;
bin_samples = bin_time/sample_time;     % samples per bin

% As in the rest of the script, only bins 1..num_bins-1 are evaluated.
bin_ids = 1:(num_bins - 1);

for k = 1:num_RGC
    % Column copy, so the binning works for row- or column-shaped spike lists.
    spike_samples = S(k).spike_timing(:);

    % Preallocate, so cells without any spikes still get well-defined zeros.
    S(k).spikecount = zeros(1, numel(bin_ids));
    S(k).spikefreq = zeros(1, numel(bin_ids));

    for i = bin_ids
        in_bin = spike_samples > (i-1)*bin_samples & spike_samples <= i*bin_samples;
        spike_timing_within = spike_samples(in_bin);
        count = numel(spike_timing_within);
        S(k).spikecount(i) = count;

        % A frequency needs at least one interval, i.e. two spikes. Empty and
        % single-spike bins keep the preallocated 0 (previously an empty bin
        % produced 0/0 = NaN, and a cell without spikes raised an error).
        if count > 1
            % count spikes span count-1 intervals; average those intervals.
            spike_period = mean(diff(spike_timing_within))*sample_time;   % ms
            S(k).spikefreq(i) = 1000/spike_period;                        % Hz
        end
    end
end

% ---- ON cells: summed spike count per bin ----
% ON_RGCs is derived from every row of the parameter table, but spikes were
% only detected for cells 1..num_RGC. Ignore indices that were never processed
% instead of failing with an out-of-range index.
ON_SpikeCount = zeros((num_bins-1), 1);
processed_ON = ON_RGCs(ON_RGCs <= numel(num_spikes));
for k = 1:numel(processed_ON)
    cell_idx = processed_ON(k);
    if num_spikes(cell_idx) > 1             % cells with at most one spike are skipped
        ON_SpikeCount = ON_SpikeCount + ...
            reshape(S(cell_idx).spikecount(1:(num_bins-1)), [], 1);
    end
end

fig = figure;
hold on;
%line([20 20], [0 70], 'Color', 'k', 'LineStyle','--');
%title("ON RGC Spike Count")
x = 1:1:(num_bins-1);
% Mean count per cell divided by the bin width in seconds (spikes per second).
% max(..., 1) avoids 0/0 = NaN bars when there is no ON cell to average over.
bar(x, (ON_SpikeCount/max(numel(processed_ON), 1))/(bin_time/1000),'blue')
fig.Position = ([200 300 300 80]);
ylim([0, 1]);       % y-axis is clipped at 1
box off

% ---- OFF cells: same procedure ----
OFF_SpikeCount = zeros((num_bins-1), 1);
processed_OFF = OFF_RGCs(OFF_RGCs <= numel(num_spikes));
for k = 1:numel(processed_OFF)
    cell_idx = processed_OFF(k);
    if num_spikes(cell_idx) > 1
        OFF_SpikeCount = OFF_SpikeCount + ...
            reshape(S(cell_idx).spikecount(1:(num_bins-1)), [], 1);
    end
end

fig = figure;
hold on;
%line([20 20], [0 30], 'Color', 'k', 'LineStyle','--');
%title("OFF RGC Spike Count")
x = 1:1:(num_bins-1);
bar(x, (OFF_SpikeCount/max(numel(processed_OFF), 1))/(bin_time/1000),'r')
fig.Position = ([200 300 300 80]);
ylim([0, 1]);       % y-axis is clipped at 1
box off

%% Calculate and plot the frequency response

% Smooth the binned spike counts with a causal moving average over
% windowSize bins (all filter coefficients equal 1/windowSize).
windowSize = 50; 
b = (1/windowSize)*ones(1,windowSize);
filtered_OFF = filter(b, 1, OFF_SpikeCount);
filtered_ON = filter(b, 1, ON_SpikeCount);

fig = figure;
hold on
plot(filtered_ON, 'b');
plot(filtered_OFF, 'r');
fig.Position = ([900 300 300 120]);

% remove first 100ms of signal
% unstable is counted in bins (1 ms each with bin_time = 1). Start at
% unstable+1 so that exactly `unstable` bins are dropped; `unstable:end`
% dropped only unstable-1 of them.
unstable = 100;
filtered_OFF = filtered_OFF((unstable+1):end);
filtered_ON = filtered_ON((unstable+1):end);

% normalise around 0: subtract the midpoint between minimum and maximum
filtered_OFF = filtered_OFF-(min(filtered_OFF)+(max(filtered_OFF-min(filtered_OFF))/2));
filtered_ON = filtered_ON-(min(filtered_ON)+(max(filtered_ON-min(filtered_ON))/2));

% NOTE(review): the offset is hard-coded; it has the same min + (max-min)/2
% form as above with min = -110.3 and max = -77.52. Confirm these values
% still match the trace that is loaded.
OFFRGC = OFFRGC - (-110.3 + (110.3 - 77.52)/2);

fig = figure;
n = 2^nextpow2(runtime);
n = n-unstable;
% NOTE(review): the OFF spectrum is taken from the raw OFFRGC trace, not from
% filtered_OFF, and fft(X, n) only uses the first n samples when X is longer.
FFT_OFF = fft(OFFRGC, n);
FFT_ON = fft(filtered_ON, n);

% Single-sided amplitude spectra: keep bins 0..n/2 and double every bin
% except the first and last.
P2_OFF = abs(FFT_OFF/n);
P1_OFF = P2_OFF(1:n/2+1);
P1_OFF(2:end-1) = 2*P1_OFF(2:end-1);

P2_ON = abs(FFT_ON/n);
P1_ON = P2_ON(1:n/2+1);
P1_ON(2:end-1) = 2*P1_ON(2:end-1);

% Frequency axis for the raw trace, which has sample_freq samples per second.
x_vector = (sample_freq/n)*(0:(n/2));
% The binned ON signal has one sample per bin_time ms, so it needs its own axis.
x_vector_ON = ((1000/bin_time)/n)*(0:(n/2));
hold on
plot(x_vector, P1_OFF, 'r');
%plot(x_vector_ON, P1_ON, 'b');
xlim([0,30]);
fig.Position = ([900 300 300 120]);